!--------------------------------------------------------------------------------------------------!
! Copyright (C) by the DBCSR developers group - All rights reserved                                !
! This file is part of the DBCSR library.                                                          !
!                                                                                                  !
! For information on the license, see the LICENSE file.                                            !
! For further information please visit https://dbcsr.cp2k.org                                      !
! SPDX-License-Identifier: GPL-2.0+                                                                !
!--------------------------------------------------------------------------------------------------!

PROGRAM dbcsr_tas_unittest

   !! Unit testing for tall-and-skinny matrices
   !!
   !! Builds six random test matrices (A, B, C and the counterparts At, Bt, Ct), reports
   !! their split info, and then calls dbcsr_tas_test_mm for the operand pairs (B, A), (A, C)
   !! and (C, B) with every combination of the three transpose flags.
   USE dbcsr_api, ONLY: dbcsr_finalize_lib, &
                        dbcsr_init_lib, &
                        dbcsr_print_statistics
   USE dbcsr_tas_base, ONLY: dbcsr_tas_destroy, &
                             dbcsr_tas_info, &
                             dbcsr_tas_nblkcols_total, &
                             dbcsr_tas_nblkrows_total, &
                             dbcsr_tas_create
   USE dbcsr_tas_types, ONLY: dbcsr_tas_type
   USE dbcsr_tas_test, ONLY: dbcsr_tas_random_bsizes, &
                             dbcsr_tas_setup_test_matrix, &
                             dbcsr_tas_test_mm, &
                             dbcsr_tas_reset_randmat_seed
   USE dbcsr_kinds, ONLY: int_8, &
                          real_8
   USE dbcsr_machine, ONLY: default_output_unit
   USE dbcsr_mpiwrap, ONLY: mp_comm_free, &
                            mp_environ, &
                            mp_world_finalize, &
                            mp_world_init, mp_comm_type
   USE dbcsr_tas_io, ONLY: dbcsr_tas_write_split_info
#include "base/dbcsr_base_uses.f90"

   IMPLICIT NONE

   ! Sizes handed to the test matrix setup and used as lengths of the block size arrays
   INTEGER(KIND=int_8), PARAMETER :: m = 100, k = 20, n = 10
   ! Kind-suffixed literals: an unsuffixed 0.1 or 1.0E-08 is a default-real constant and
   ! would only carry single precision into these real_8 parameters
   REAL(KIND=real_8), PARAMETER   :: sparsity = 0.1_real_8
   REAL(KIND=real_8), PARAMETER   :: filter_eps = 1.0E-08_real_8

   ! Input matrices and the matching result matrices (created from the inputs as templates)
   TYPE(dbcsr_tas_type)           :: A, B, C, At, Bt, Ct
   TYPE(dbcsr_tas_type)           :: A_out, B_out, C_out, At_out, Bt_out, Ct_out
   INTEGER, DIMENSION(m)          :: bsize_m
   INTEGER, DIMENSION(n)          :: bsize_n
   INTEGER, DIMENSION(k)          :: bsize_k
   INTEGER                        :: numnodes, mynode, io_unit
   ! World communicator plus one communicator returned per test matrix (freed at the end)
   TYPE(mp_comm_type)             :: mp_comm, mp_comm_A, mp_comm_At, mp_comm_B, mp_comm_Bt, mp_comm_C, mp_comm_Ct

   CALL mp_world_init(mp_comm)

   CALL mp_environ(numnodes, mynode, mp_comm)

   ! Only rank 0 gets a real output unit, every other rank passes -1
   io_unit = -1
   IF (mynode .EQ. 0) io_unit = default_output_unit

   CALL dbcsr_init_lib(mp_comm%get_handle(), io_unit)

   CALL dbcsr_tas_reset_randmat_seed()

   ! NOTE(review): the literal arrays and the integer that follows them are forwarded as is;
   ! their meaning is defined by dbcsr_tas_random_bsizes - confirm there
   CALL dbcsr_tas_random_bsizes([13, 8, 5, 25, 12], 2, bsize_m)
   CALL dbcsr_tas_random_bsizes([3, 78, 33, 12, 3, 15], 1, bsize_n)
   CALL dbcsr_tas_random_bsizes([9, 64, 23, 2], 3, bsize_k)

   ! Each matrix and its counterpart are set up with swapped sizes and block size arrays
   CALL dbcsr_tas_setup_test_matrix(A, mp_comm_A, mp_comm, m, k, bsize_m, bsize_k, [5, 1], "A", sparsity)
   CALL dbcsr_tas_setup_test_matrix(At, mp_comm_At, mp_comm, k, m, bsize_k, bsize_m, [3, 8], "A^t", sparsity)
   CALL dbcsr_tas_setup_test_matrix(B, mp_comm_B, mp_comm, n, m, bsize_n, bsize_m, [3, 2], "B", sparsity)
   CALL dbcsr_tas_setup_test_matrix(Bt, mp_comm_Bt, mp_comm, m, n, bsize_m, bsize_n, [1, 3], "B^t", sparsity)
   CALL dbcsr_tas_setup_test_matrix(C, mp_comm_C, mp_comm, k, n, bsize_k, bsize_n, [5, 7], "C", sparsity)
   CALL dbcsr_tas_setup_test_matrix(Ct, mp_comm_Ct, mp_comm, n, k, bsize_n, bsize_k, [1, 1], "C^t", sparsity)

   CALL dbcsr_tas_create(A, A_out)
   CALL dbcsr_tas_create(At, At_out)
   CALL dbcsr_tas_create(B, B_out)
   CALL dbcsr_tas_create(Bt, Bt_out)
   CALL dbcsr_tas_create(C, C_out)
   CALL dbcsr_tas_create(Ct, Ct_out)

   IF (mynode == 0) WRITE (io_unit, '(A)') "DBCSR TALL-AND-SKINNY MATRICES"
   CALL print_split_info(A, "A")
   CALL print_split_info(At, "At")
   CALL print_split_info(B, "B")
   CALL print_split_info(Bt, "Bt")
   CALL print_split_info(C, "C")
   CALL print_split_info(Ct, "Ct")

   ! Operand pair (B, A): results go to Ct_out or C_out depending on the third flag
   CALL dbcsr_tas_test_mm('N', 'N', 'N', B, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'N', Bt, A, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'N', B, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'N', Bt, At, Ct_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'N', 'T', B, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'T', Bt, A, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'T', B, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'T', Bt, At, C_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Operand pair (A, C): results go to Bt_out or B_out depending on the third flag
   CALL dbcsr_tas_test_mm('N', 'N', 'N', A, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'N', At, C, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'N', A, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'N', At, Ct, Bt_out, unit_nr=io_unit, filter_eps=filter_eps)

   CALL dbcsr_tas_test_mm('N', 'N', 'T', A, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'T', At, C, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'T', A, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'T', At, Ct, B_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Operand pair (C, B): results go to At_out or A_out depending on the third flag
   CALL dbcsr_tas_test_mm('N', 'N', 'N', C, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'N', Ct, B, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'N', C, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'N', Ct, Bt, At_out, unit_nr=io_unit, filter_eps=filter_eps)

   CALL dbcsr_tas_test_mm('N', 'N', 'T', C, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'N', 'T', Ct, B, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('N', 'T', 'T', C, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)
   CALL dbcsr_tas_test_mm('T', 'T', 'T', Ct, Bt, A_out, unit_nr=io_unit, filter_eps=filter_eps)

   ! Release matrices first, then the per-matrix communicators
   CALL dbcsr_tas_destroy(A)
   CALL dbcsr_tas_destroy(At)
   CALL dbcsr_tas_destroy(B)
   CALL dbcsr_tas_destroy(Bt)
   CALL dbcsr_tas_destroy(C)
   CALL dbcsr_tas_destroy(Ct)
   CALL dbcsr_tas_destroy(A_out)
   CALL dbcsr_tas_destroy(At_out)
   CALL dbcsr_tas_destroy(B_out)
   CALL dbcsr_tas_destroy(Bt_out)
   CALL dbcsr_tas_destroy(C_out)
   CALL dbcsr_tas_destroy(Ct_out)

   CALL mp_comm_free(mp_comm_A)
   CALL mp_comm_free(mp_comm_At)
   CALL mp_comm_free(mp_comm_B)
   CALL mp_comm_free(mp_comm_Bt)
   CALL mp_comm_free(mp_comm_C)
   CALL mp_comm_free(mp_comm_Ct)

   CALL dbcsr_print_statistics(.TRUE.)
   CALL dbcsr_finalize_lib()

   CALL mp_world_finalize()

CONTAINS

   SUBROUTINE print_split_info(matrix, label)
      !! Report the total block row and column counts of a matrix on rank 0, then write its split info

      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix
         !! matrix to report on; declared INOUT so that it can be forwarded to the query
         !! routines whatever intent they declare
      CHARACTER(LEN=*), INTENT(IN)        :: label
         !! name forwarded to dbcsr_tas_write_split_info

      IF (mynode == 0) WRITE (io_unit, '(1X, A, 1X, A, I10, 1X, A, 1X, I10)') "Split info for matrix", &
         TRIM(matrix%matrix%name), dbcsr_tas_nblkrows_total(matrix), 'X', dbcsr_tas_nblkcols_total(matrix)
      ! Called on every rank, exactly as the inline code did; io_unit is -1 away from rank 0
      CALL dbcsr_tas_write_split_info(dbcsr_tas_info(matrix), io_unit, name=label)
   END SUBROUTINE print_split_info

END PROGRAM dbcsr_tas_unittest
